k-最近傍法(k-Nearest Neighbor, k-NN)
Overview
k-最近傍法(k-NN)アルゴリズムは、最も単純な学習アルゴリズムであると言われています。訓練データセットを学習し、予測を行う際には、訓練データセットの中から一番近い点、つまり「最近傍点」を見つけます。近傍点は1つとは限らず、任意個(k個)の近傍点を考慮することもできます。
下で図示しているのは2クラス分類の場合ですが、任意のクラス数に対しても適用できます。それぞれのクラスに対して近傍点がいくつあるかを数えて、最も多いクラスを予測値とします。また、分類問題だけでなく、回帰問題に用いることもできます。
1-最近傍法/3-最近傍法
https://gyazo.com/54602f1d38796c9c1760795bcb3b374bhttps://gyazo.com/b1f95d8ffbda9952ebc3c82cafbe4952
Theory
距離の算出には、一般的にユークリッド距離やマンハッタン距離が使われます。
ユークリッド距離の数式
$ d = \sqrt{(b_1 - a_1)^2 + (b_2 - a_2)^2}
マンハッタン距離の数式
$ d = |(b_1 - a_1)| + |(b_2 - a_2)|
ユークリッド距離/マンハッタン距離
https://gyazo.com/59dace33c3f9c1236e8633000d928432
Coding(Classification)
forgeデータセットでモデルを構築・学習・予測・評価して決定境界を描画する
code: Python
import numpy as np
import matplotlib.pyplot as plt
import mglearn
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
X, y = mglearn.datasets.make_forge()
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = KNeighborsClassifier(n_neighbors=3)
clf.fit(X_train, y_train)
print('Test set predictions: {}'.format(clf.predict(X_test)))
print('Test set accuracy: {:.2f}'.format(clf.score(X_test, y_test)))
# 決定境界を描画する
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
for n_neighbors, ax in zip(1, 3, 9, axes): # axesはAxesオブジェクトの1x3の配列 # fitメソッドは自分自身を返すので、1行でインスタンスを生成してfitできる
clf = KNeighborsClassifier(n_neighbors=n_neighbors).fit(X, y)
mglearn.plots.plot_2d_separator(clf, X, fill=True, eps=0.5, ax=ax, alpha=.4)
mglearn.discrete_scatter(X:, 0, X:, 1, y, ax=ax) ax.set_title('{} neighbors(s)'.format(n_neighbors))
ax.set_xlabel('feature 0')
ax.set_ylabel('feature 1')
plt.show()
--------------------------------------------------------------------------
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=1, n_neighbors=3, p=2,
weights='uniform')
Test set accuracy: 0.86
---------------------------------------------------------------------------
https://gyazo.com/6ae6546617d68acdc51081b7b78f1888
左の図からわかるように、1つの最近傍点のみを用いると、決定境界は、訓練データに近くなります。つまり、最近傍点が少ない場合は複雑度の高いモデルに対応し、最近傍点が多い場合は複雑度の低いモデルに対応します。すなわち、最近傍点が多いほうがより汎化性能がありそうだと分かります。
cancerデータセットでモデルを構築・学習・評価してモデルの複雑さと汎化性能の関係を調べる
code: Python
import numpy as np
import matplotlib.pyplot as plt
import mglearn
from sklearn.datasets import load_breast_cancer
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(cancer.data, cancer.target, stratify=cancer.target, random_state=66)
training_accuracy = []
test_accuracy = []
# n_neighborsを1から10まで試す
neighbors_settings = range(1, 11)
for n_neighbors in neighbors_settings:
# モデルを構築
clf = KNeighborsClassifier(n_neighbors=n_neighbors)
clf.fit(X_train, y_train)
training_accuracy.append(clf.score(X_train, y_train))
test_accuracy.append(clf.score(X_test, y_test))
plt.plot(neighbors_settings, training_accuracy, label='training accuracy')
plt.plot(neighbors_settings, test_accuracy, label='test accuracy', linestyle='--')
plt.ylabel('Accuracy')
plt.xlabel('n_neighbors')
plt.legend()
plt.show()
https://gyazo.com/159b30a392e26c782cd7e7a28aaece12
最近傍点が少ない場合、訓練データセットには高い精度を示しますが、テストデートセットの精度は低いです。(過剰適合)
最近傍点がある程度多い場合、訓練データセット、テストデータセットともに高い精度が出ています。
最近傍点が多すぎる場合、訓練データセット、テストデータセットともに精度が低いです。(適合不足)
Coding(Regression)
waveデータセットでモデルを構築・学習・予測・評価して全てのデータポイントに対する予測値を調べる
code: Python
import numpy as np
import matplotlib.pyplot as plt
import mglearn
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import train_test_split
X, y = mglearn.datasets.make_wave(n_samples=40)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
# 3つの最近傍点を考慮するように設定してモデルのインスタンスを生成
reg = KNeighborsRegressor(n_neighbors=3)
reg.fit(X_train, y_train)
print('Test set predictions:\n{}'.format(reg.predict(X_test)))
# 予測値を描画する
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
# -3から3までの間に1000点のデータポイントを作る
line = np.linspace(-3, 3, 1000).reshape(-1, 1)
for n_neighbors, ax in zip(1, 3, 9, axes): # 1, 3, 9近傍点で予測
reg = KNeighborsRegressor(n_neighbors=n_neighbors)
reg.fit(X_train, y_train)
ax.plot(line, reg.predict(line))
ax.plot(X_train, y_train, '^', c=mglearn.cm2(0), markersize=8)
ax.plot(X_test, y_test, 'v', c=mglearn.cm2(1), markersize=8)
ax.set_title('{} neighbors(s)\n train score: {:.2f} test score: {:.2f}'.format(
n_neighbors, reg.score(X_train, y_train), reg.score(X_test, y_test)
))
ax.set_xlabel('Feature')
ax.set_ylabel('Target')
plt.show()
--------------------------------------------------------------------------
Test set predictions:
[-0.05396539 0.35686046 1.13671923 -1.89415682 -1.13881398 -1.63113382
0.35686046 0.91241374 -0.44680446 -1.13881398]
Test set R^2: 0.83
--------------------------------------------------------------------------
https://gyazo.com/4538269b546e3c74a9a24a9e445c776e
このグラフからわかるように、近傍点が1つの場合には、予測値がすべて訓練データポイントを通っており、非常に不安定な予測になっています。また、考慮する近傍点を増やしていくと、予測は滑らかになりますが、訓練データセットに対する適合度は下がっていきます。
Summary
Merit
モデルが理解しやすい
パラメータをあまり調整しなくてもいい
Demerit
多数の特徴量(数百以上)を持つデータセットではうまく機能しない
ほとんどの特徴量が0となるような疎なデータセットに対しての性能が悪い
実務ではほとんど使われていない
Parameters
近傍点の数
3や5程度の小さな数で十分な場合がほとんど
データポイント間の距離測度
デフォルトでユークリッド距離を用いる(ほとんどの場合これでうまくいく)